{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sCW-xvckaO8s"
   },
   "source": [
    "# Live Colab Example\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-12T09:39:54.418700Z",
     "start_time": "2020-09-12T09:39:54.415303Z"
    },
    "id": "16gWjoEUXOhl"
   },
   "source": [
    "## Dependencies and Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "jc7ZqfooYZnD"
   },
   "outputs": [],
   "source": [
    "#@title Install dependencies\n",
    "\n",
    "!pip install -q omegaconf torchaudio pydub\n",
    "\n",
    "import os\n",
    "from os.path import exists\n",
    "\n",
    "if not exists('silero-models'):\n",
    "  !git clone -q --depth 1 https://github.com/snakers4/silero-models\n",
    "\n",
    "%cd silero-models\n",
    "\n",
    "# silero imports\n",
    "import torch\n",
    "import random\n",
    "from glob import glob\n",
    "from omegaconf import OmegaConf\n",
    "from src.silero.utils import (init_jit_model, \n",
    "                       split_into_batches,\n",
    "                       read_audio,\n",
    "                       read_batch,\n",
    "                       prepare_model_input)\n",
    "from colab_utils import (record_audio,\n",
    "                         audio_bytes_to_np,\n",
    "                         upload_audio)\n",
    "\n",
    "device = torch.device('cpu')   # you can use any pytorch device\n",
    "models = OmegaConf.load('models.yml')\n",
    "\n",
    "# imports for uploading/recording\n",
    "import numpy as np\n",
    "import ipywidgets as widgets\n",
    "from scipy.io import wavfile\n",
    "from IPython.display import Audio, display, clear_output\n",
    "from torchaudio.functional import vad\n",
    "\n",
    "\n",
    "# wav to text method\n",
    "def wav_to_text(f='test.wav'):\n",
    "  batch = read_batch([f])\n",
    "  input = prepare_model_input(batch, device=device)\n",
    "  output = model(input)\n",
    "  return decoder(output[0].cpu())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_6QIeg7XffsO"
   },
   "source": [
    "## Transcribe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "NJAOV1bbhEv0"
   },
   "outputs": [],
   "source": [
    "#@markdown { run: \"auto\" }\n",
    "\n",
    "language = \"English\" #@param [\"English\", \"German\", \"Spanish\"]\n",
    "\n",
    "print(language)\n",
    "if language == 'German':\n",
    "  model, decoder = init_jit_model(models.stt_models.de.latest.jit, device=device)\n",
    "elif language == \"Spanish\":\n",
    "  model, decoder = init_jit_model(models.stt_models.es.latest.jit, device=device)\n",
    "else:\n",
    "  model, decoder = init_jit_model(models.stt_models.en.latest.jit, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "FYsz_90gTQh-"
   },
   "outputs": [],
   "source": [
    "#@markdown { run: \"auto\" }\n",
    "\n",
    "use_VAD = \"No\" #@param [\"Yes\", \"No\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "QttWasy5hUd6"
   },
   "outputs": [],
   "source": [
    "#@markdown Either record audio from microphone or upload audio from file (.mp3 or .wav) { run: \"auto\" }\n",
    "\n",
    "record_or_upload = \"Record\" #@param [\"Record\", \"Upload (.mp3 or .wav)\"]\n",
    "record_seconds =   4#@param {type:\"number\", min:1, max:10, step:1}\n",
    "sample_rate = 16000\n",
    "\n",
    "def _apply_vad(audio, boot_time=0, trigger_level=9, **kwargs):\n",
    "  print('\\nVAD applied\\n')\n",
    "  vad_kwargs = dict(locals().copy(), **kwargs)\n",
    "  vad_kwargs['sample_rate'] = sample_rate\n",
    "  del vad_kwargs['kwargs'], vad_kwargs['audio']\n",
    "  audio = vad(torch.flip(audio, ([0])), **vad_kwargs)\n",
    "  return vad(torch.flip(audio, ([0])), **vad_kwargs)\n",
    "\n",
    "def _recognize(audio):\n",
    "  display(Audio(audio, rate=sample_rate, autoplay=True))\n",
    "  if use_VAD == \"Yes\":\n",
    "    audio = _apply_vad(audio)\n",
    "  wavfile.write('test.wav', sample_rate, (32767*audio).numpy().astype(np.int16))\n",
    "  transcription = wav_to_text()\n",
    "  print('\\n\\nTRANSCRIPTION:\\n')\n",
    "  print(transcription)\n",
    "\n",
    "def _record_audio(b):\n",
    "  clear_output()\n",
    "  audio = record_audio(record_seconds)\n",
    "  wavfile.write('recorded.wav', sample_rate, (32767*audio).numpy().astype(np.int16))\n",
    "  _recognize(audio)\n",
    "\n",
    "def _upload_audio(b):\n",
    "  clear_output()\n",
    "  audio = upload_audio()\n",
    "  _recognize(audio)\n",
    "  return audio\n",
    "\n",
    "if record_or_upload == \"Record\":\n",
    "  button = widgets.Button(description=\"Record Speech\")\n",
    "  button.on_click(_record_audio)\n",
    "  display(button)\n",
    "else:\n",
    "  audio = _upload_audio(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "E-bFGpn_TQiW"
   },
   "outputs": [],
   "source": [
    "#@markdown Check audio after applying VAD { run: \"auto\" }\n",
    "\n",
    "if record_or_upload == \"Record\":\n",
    "  audio = read_audio('recorded.wav', sample_rate)\n",
    "display(Audio(_apply_vad(audio), rate=sample_rate, autoplay=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-11T13:31:58.954518Z",
     "start_time": "2020-09-11T13:31:58.952259Z"
    },
    "id": "nMkcU8sDXOh8"
   },
   "source": [
    "# PyTorch Example\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "xj7emOprcPQ6"
   },
   "outputs": [],
   "source": [
    "#@title Install Dependencies\n",
    "\n",
    "# this assumes that you have a relevant version of PyTorch installed\n",
    "!pip install -q torchaudio omegaconf\n",
    "\n",
    "import os\n",
    "from os.path import exists\n",
    "\n",
    "if not exists('silero-models'):\n",
    "  !git clone -q --depth 1 https://github.com/snakers4/silero-models\n",
    "\n",
    "%cd silero-models\n",
    "\n",
    "import torch\n",
    "import random\n",
    "from glob import glob\n",
    "from omegaconf import OmegaConf\n",
    "from src.silero.utils import (init_jit_model, \n",
    "                       split_into_batches,\n",
    "                       read_batch,\n",
    "                       prepare_model_input)\n",
    "from IPython.display import display, Audio"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nYG1dBgBDN5S"
   },
   "source": [
    "## Minimal example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jwKw-yMpDQTL"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import zipfile\n",
    "import torchaudio\n",
    "from glob import glob\n",
    "\n",
    "device = torch.device('cpu')  # gpu also works, but our models are fast enough for CPU\n",
    "model, decoder, utils = torch.hub.load(repo_or_dir='snakers4/silero-models',\n",
    "                                       model='silero_stt',\n",
    "                                       jit_model='jit_xlarge',\n",
    "                                       language='en', # also available 'de', 'es'\n",
    "                                       device=device)\n",
    "(read_batch, split_into_batches,\n",
    " read_audio, prepare_model_input) = utils  # see function signature for details\n",
    "\n",
    "# download a single file, any format compatible with TorchAudio\n",
    "torch.hub.download_url_to_file('https://opus-codec.org/static/examples/samples/speech_orig.wav',\n",
    "                               dst ='speech_orig.wav', progress=True)\n",
    "test_files = glob('speech_orig.wav') \n",
    "batches = split_into_batches(test_files, batch_size=10)\n",
    "input = prepare_model_input(read_batch(batches[0]),\n",
    "                            device=device)\n",
    "\n",
    "output = model(input)\n",
    "for example in output:\n",
    "    print(decoder(example.cpu()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8r3DW7IgkJil"
   },
   "source": [
    "## More examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-11T14:21:25.234818Z",
     "start_time": "2020-09-11T14:21:25.218179Z"
    },
    "id": "GE0S5kmdXOiG"
   },
   "outputs": [],
   "source": [
    "models = OmegaConf.load('models.yml')  # all available models are listed in the yml file\n",
    "print(list(models.stt_models.keys()),\n",
    "      list(models.stt_models.en.keys()),\n",
    "      list(models.stt_models.en.latest.keys()),\n",
    "      models.stt_models.en.latest.jit)\n",
    "device = torch.device('cpu')   # you can use any pytorch device\n",
    "model, decoder = init_jit_model(models.stt_models.en.latest.jit, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-11T14:21:26.056045Z",
     "start_time": "2020-09-11T14:21:26.040771Z"
    },
    "id": "GSUsZ7cqXOiL"
   },
   "outputs": [],
   "source": [
    "device = torch.device('cpu')   # you can use any pytorch device\n",
    "model, decoder = init_jit_model(models.stt_models.en.latest.jit, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-11T14:25:14.996913Z",
     "start_time": "2020-09-11T14:21:40.831866Z"
    },
    "id": "paW8mugZXOiP"
   },
   "outputs": [],
   "source": [
    "test_files = glob('*.wav')  # replace with your data\n",
    "batches = split_into_batches(test_files, batch_size=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-11T13:57:09.061692Z",
     "start_time": "2020-09-11T13:57:08.992493Z"
    },
    "id": "JryVNe5hXOiR"
   },
   "outputs": [],
   "source": [
    "# transcribe a set of files\n",
    "input = prepare_model_input(read_batch(random.sample(batches, k=1)[0]),\n",
    "                            device=device)\n",
    "output = model(input)\n",
    "for example in output:\n",
    "    print(decoder(example.cpu()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-11T14:38:32.972790Z",
     "start_time": "2020-09-11T14:38:31.605231Z"
    },
    "id": "FgvdCQRSXOiY"
   },
   "outputs": [],
   "source": [
    "# listen to one file\n",
    "batch = read_batch(random.sample(batches, k=1)[0])\n",
    "input = prepare_model_input(batch,\n",
    "                            device=device)\n",
    "output = model(input)\n",
    "\n",
    "for i, example in enumerate(output):\n",
    "    print(decoder(example.cpu()))\n",
    "    display(Audio(batch[i], rate=16000))  # audio was resampled to 16kHz\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zi4m9qp4X4ma"
   },
   "outputs": [],
   "source": [
    "# align example\n",
    "batch = read_batch(random.sample(batches, k=1)[0])\n",
    "input = prepare_model_input(batch,\n",
    "                            device=device)\n",
    "\n",
    "wav_len = input.shape[1] / 16000\n",
    "\n",
    "output = model(input)\n",
    "\n",
    "for i, example in enumerate(output):\n",
    "    print(decoder(example.cpu(), wav_len, word_align=True)[-1])\n",
    "    display(Audio(batch[i], rate=16000))  # audio was resampled to 16kHz\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lfa1Za1JUgUw"
   },
   "source": [
    "# ONNX Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "ku78lggJUm_3"
   },
   "outputs": [],
   "source": [
    "#@title Install and Import Dependencies\n",
    "\n",
    "# this assumes that you have a relevant version of PyTorch installed\n",
    "!pip install -q torchaudio omegaconf onnx onnxruntime"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "i72mHSBbaG3p"
   },
   "source": [
    "## Minimal example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ed20MdPEEt3C"
   },
   "outputs": [],
   "source": [
    "import onnx\n",
    "import torch\n",
    "import onnxruntime\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "language = 'en' # also available 'de', 'es'\n",
    "\n",
    "# load provided utils\n",
    "_, decoder, utils = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_stt', language=language)\n",
    "(read_batch, split_into_batches,\n",
    " read_audio, prepare_model_input) = utils\n",
    "\n",
    " # see available models\n",
    "torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml', 'models.yml')\n",
    "models = OmegaConf.load('models.yml')\n",
    "available_languages = list(models.stt_models.keys())\n",
    "assert language in available_languages\n",
    "\n",
    "# load the actual ONNX model\n",
    "torch.hub.download_url_to_file(models.stt_models.en.latest.onnx, 'model.onnx', progress=True)\n",
    "onnx_model = onnx.load('model.onnx')\n",
    "onnx.checker.check_model(onnx_model)\n",
    "ort_session = onnxruntime.InferenceSession('model.onnx')\n",
    "\n",
    "# download a single file, any format compatible with TorchAudio\n",
    "torch.hub.download_url_to_file('https://opus-codec.org/static/examples/samples/speech_orig.wav', dst ='speech_orig.wav', progress=True)\n",
    "test_files = ['speech_orig.wav']\n",
    "batches = split_into_batches(test_files, batch_size=10)\n",
    "input = prepare_model_input(read_batch(batches[0]))\n",
    "\n",
    "# actual onnx inference and decoding\n",
    "onnx_input = input.detach().cpu().numpy()\n",
    "ort_inputs = {'input': onnx_input}\n",
    "ort_outs = ort_session.run(None, ort_inputs)\n",
    "decoded = decoder(torch.Tensor(ort_outs[0])[0])\n",
    "print(decoded)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "colab_examples.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
